# Run-mode switches.
#   DEBUG_MODE      -- when True, main() calls _set_debug_mode() to shrink trainer_params.
#   USE_CUDA        -- copied into trainer_params['use_cuda'].
#   CUDA_DEVICE_NUM -- copied into trainer_params['cuda_device_num'].
DEBUG_MODE, USE_CUDA, CUDA_DEVICE_NUM = False, True, 0


##########################################################################################
# Path Config

import os
import sys
import torch

# Switch the working directory to this script's folder so relative paths used
# later (the sys.path entries below, './result/...' in trainer_params) resolve
# against it no matter where the script was launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Relative sys.path entries are resolved against the current working
# directory, so these depend on the chdir above. Order matters: "../.." is
# inserted last at index 0 and is therefore searched before "..".
sys.path.insert(0, "..")  # for problem_def
sys.path.insert(0, "../..")  # for utils

##########################################################################################
import logging

from utils import create_logger
from PFSPTrainer import PFSPTrainer as Trainer


##########################################################################################
# Shared sizes reused by the environment and model configs below.
pomo_size = 128  # passed as 'pomo_size' and 'eval_pomo_size'
n_mc = 5  # passed as 'mc_cnt' (and as 'one_hot_seed_cnt')

# Training environment settings.
env_params = {
    'job_cnt': 20,
    'mc_cnt': n_mc,
    'pomo_size': pomo_size,
    'mode': 'uniform',
}

# Evaluation environment settings: an independent copy of the training
# settings, so the two cannot drift apart by accident. Override individual
# keys here if evaluation should differ from training.
eval_env_params = dict(env_params)

# Attention widths. The sqrt_* entries below are derived from these so that
# changing a width can never leave its square root stale.
_embedding_dim = 128
_qkv_dim = 16

model_params = {
    'embedding_dim': _embedding_dim,
    'sqrt_embedding_dim': _embedding_dim**(1/2),
    'encoder_layer_num': 6,
    'qkv_dim': _qkv_dim,
    'sqrt_qkv_dim': _qkv_dim**(1/2),
    # NOTE(review): head_num * qkv_dim == embedding_dim here (8 * 16 == 128);
    # confirm whether the model relies on that before changing any of them.
    'head_num': 8,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    # NOTE(review): the 16 here appears to mirror ms_hidden_dim; confirm
    # before changing either value.
    'ms_layer2_init': (1/16)**(1/2),
    'eval_type': 'softmax',
    # NOTE(review): must be >= the number of items that receive one-hot
    # seeds. It is tied to mc_cnt here (not job_cnt); confirm in the model
    # which dimension that is before changing job_cnt or n_mc.
    'one_hot_seed_cnt': n_mc,
    'latent_cont_size': 4,
    'latent_disc_size': 12,
    'temperature': 1,
    'pomo_size': pomo_size,
    'mc_cnt': n_mc,
    'eval_pomo_size': pomo_size,
}

# Optimizer and learning-rate schedule settings handed to the trainer.
# NOTE(review): milestones/gamma look like MultiStepLR-style arguments (scale
# the LR by gamma at each listed epoch); confirm in PFSPTrainer.
optimizer_params = dict(
    optimizer=dict(lr=1e-4, weight_decay=1e-6),
    scheduler=dict(milestones=[900, 950], gamma=0.1),
)

# Training-loop settings handed to the trainer. When DEBUG_MODE is on,
# _set_debug_mode() overwrites the epoch/episode/batch entries in place.
trainer_params = dict(
    use_cuda=USE_CUDA,
    cuda_device_num=CUDA_DEVICE_NUM,
    epochs=1000,
    train_episodes=100_000,
    train_batch_size=100,
    eval_episode=1000,
    eval_batch_size=100,
    accumulation_step=2,
    # Optional model loading (off while enable is False).
    # NOTE(review): '20x5' in the path is hard-coded and does not follow
    # env_params['job_cnt'] / n_mc -- update it together with them.
    model_load=dict(
        enable=False,
        load_model_only=True,
        path='./result/saved_PFSP20x5',
        epoch=1000,
    ),
)

# Keyword arguments for create_logger(), called from main().
logger_params = dict(
    log_file=dict(desc='PFSP_train', filename='run_log'),
)

##########################################################################################
# main

def main():
    """Build the PFSP trainer from the module-level config dicts and run it.

    In DEBUG_MODE the trainer settings are shrunk first. The logger is then
    created and the active configuration logged before the trainer is built.
    """
    if DEBUG_MODE:
        _set_debug_mode()

    create_logger(**logger_params)
    _print_config()

    trainer_kwargs = dict(
        env_params=env_params,
        model_params=model_params,
        optimizer_params=optimizer_params,
        trainer_params=trainer_params,
        eval_env_param=eval_env_params,  # the Trainer keyword is singular
    )
    trainer = Trainer(**trainer_kwargs)

    # Source snapshot into the result folder is currently disabled:
    # copy_all_src(trainer.result_folder)

    trainer.run()


def _set_debug_mode():
    """Shrink the epoch, episode and batch sizes in trainer_params for a quick debug run.

    The module-level dict is mutated in place, not rebound, so no
    ``global`` declaration is needed.
    """
    trainer_params.update(
        epochs=2,
        train_episodes=4,
        train_batch_size=4,
        eval_episode=4,
        eval_batch_size=4,
    )


def _print_config():
    """Log the run-mode switches and every module-level ``*params`` dict.

    Each config dict is logged as its global name immediately followed by its
    repr (no separator), e.g. ``env_params{'job_cnt': 20, ...}``.
    """
    logger = logging.getLogger('root')
    logger.info('DEBUG_MODE: {}'.format(DEBUG_MODE))
    logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM))
    # Plain loop rather than a list comprehension: this runs only for its
    # logging side effect, so there is no list to build.
    for g_key, g_value in globals().items():
        if g_key.endswith('params'):
            logger.info(g_key + "{}".format(g_value))


##########################################################################################

if __name__ == "__main__":
    # Release cached, unused CUDA allocator memory before training starts.
    # Called unconditionally, regardless of USE_CUDA.
    torch.cuda.empty_cache()
    main()